{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predicting the outcome of a US presidential election using Bayesian optimal experimental design\n",
    "\n",
    "In this tutorial, we explore the use of optimal experimental design techniques to create an optimal polling strategy to predict the outcome of a US presidential election. In a [previous tutorial](http://pyro.ai/examples/working_memory.html), we explored the use of Bayesian optimal experimental design to learn the working memory capacity of a single person. Here, we apply the same concepts to study a whole country.\n",
    "\n",
    "To begin, we need a Bayesian model of the winner of the election `w`, as well as the outcome `y` of any poll we may plan to conduct. The experimental design is the number of people $n_i$ to poll in each state. To set up our exploratory model, we are going to make a number of simplifying assumptions. We will use historical election data 1976-2012 to construct a plausible prior and the 2016 election as our test set: we imagine that we are conducting polling just before the 2016 election.\n",
    "\n",
    "## Choosing a prior\n",
    "In our model, we include a 51-dimensional latent variable `alpha`. For each of the 50 states plus DC we define\n",
    "\n",
    "$$ \\alpha_i = \\text{logit }\\mathbb{P}(\\text{a random voter in state } i \\text{ votes Democrat in the 2016 election}) $$\n",
    "\n",
    "and we assume all other voters vote Republican. Right before the election, the value of $\\alpha$ is unknown and we wish to estimate it by conducting a poll with $n_i$ people in state $i$ for $i=1, \\dots, 51$. The winner $w$ of the election is decided by the Electoral College system. The number of electoral college votes gained by the Democrats in state $i$ is\n",
    "\n",
    "$$\n",
    "e_i =  \\begin{cases}\n",
    "k_i \\text{ if } \\alpha_i > 0 \\\\\n",
    "0 \\text{ otherwise}\n",
    "\\end{cases}\n",
    "$$\n",
    "\n",
    "This is a rough approximation of the true system. Since $\\alpha_i$ is a logit, the condition $\\alpha_i > 0$ is the same as a Democrat vote share above $\\frac{1}{2}$. All other electoral college votes go to the Republicans. Here $k_i$ is the number of electoral college votes allotted to state $i$, which are listed in the following data frame."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data path\n",
    "BASE_URL =  \"https://d2hg8soec8ck9v.cloudfront.net/datasets/us_elections/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "       Electoral college votes\n",
      "State                         \n",
      "AL                           9\n",
      "AK                           3\n",
      "AZ                          11\n",
      "AR                           6\n",
      "CA                          55\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import torch\n",
    "from urllib.request import urlopen\n",
    "\n",
    "electoral_college_votes = pd.read_pickle(urlopen(BASE_URL + \"electoral_college_votes.pickle\"))\n",
    "print(electoral_college_votes.head())\n",
    "ec_votes_tensor = torch.tensor(electoral_college_votes.values, dtype=torch.float).squeeze()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The winner $w$ of the election is\n",
    "\n",
    "$$ w = \\begin{cases}\n",
    "\\text{Democrats if } \\sum_i e_i > \\frac{1}{2}\\sum_i k_i  \\\\\n",
    "\\text{Republicans otherwise}\n",
    "\\end{cases}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In code, this is expressed as follows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def election_winner(alpha):\n",
    "    dem_win_state = (alpha > 0.).float()\n",
    "    dem_electoral_college_votes = ec_votes_tensor * dem_win_state\n",
    "    w = (dem_electoral_college_votes.sum(-1) / ec_votes_tensor.sum(-1) > .5).float()\n",
    "    return w"
   ]
  },
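  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check on the threshold used in `election_winner`, recall how the logit and its inverse, the sigmoid $\\sigma$, are related. Here $p_i$ is our own shorthand (not used elsewhere in this tutorial) for the probability that a random voter in state $i$ votes Democrat:\n",
    "\n",
    "$$\n",
    "\\alpha_i = \\log\\frac{p_i}{1-p_i}, \\qquad p_i = \\sigma(\\alpha_i) = \\frac{1}{1+e^{-\\alpha_i}},\n",
    "$$\n",
    "\n",
    "so $\\alpha_i > 0$ exactly when $p_i > \\frac{1}{2}$. This is why the code compares `alpha` with `0.`, while the electoral college share in the last line is compared with `.5`."
   ]
  },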
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are interested in polling strategies that will help us predict $w$, rather than predicting the more complex state-by-state results $\\alpha$.\n",
    "\n",
    "To set up a fully Bayesian model, we need a prior for $\\alpha$. We will base the prior on the outcome of some historical presidential elections. Specifically, we'll use the following dataset of state-by-state election results for the presidential elections 1976-2012 inclusive. Note that votes for parties other than Democrats and Republicans have been ignored."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "          1976                1980                1984           \n",
      "      Democrat Republican Democrat Republican Democrat Republican\n",
      "State                                                            \n",
      "AL      659170     504070   636730     654192   551899     872849\n",
      "AK       44058      71555    41842      86112    62007     138377\n",
      "AZ      295602     418642   246843     529688   333854     681416\n",
      "AR      499614     268753   398041     403164   338646     534774\n",
      "CA     3742284    3882244  3083661    4524858  3922519    5467009\n"
     ]
    }
   ],
   "source": [
    "frame = pd.read_pickle(urlopen(BASE_URL + \"us_presidential_election_data_historical.pickle\"))\n",
    "print(frame[[1976, 1980, 1984]].head())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Although this dataset covers ten elections, we will base our prior mean for $\\alpha$ solely on the 2012 election, the most recent one. Our model will be based on logistic regression, so we will transform the probability of voting Democrat using the logit function. Specifically, we'll choose a prior mean as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_2012 = torch.tensor(frame[2012].values, dtype=torch.float)\n",
    "prior_mean = torch.log(results_2012[..., 0] / results_2012[..., 1])"
   ]
  },
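  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the log-ratio of vote counts in the cell above is the logit we want, write $D_i$ and $R_i$ (our own notation) for the Democrat and Republican vote counts in state $i$. Since votes for other parties are ignored, the Democrat share is $D_i / (D_i + R_i)$, and\n",
    "\n",
    "$$\n",
    "\\text{logit}\\left(\\frac{D_i}{D_i + R_i}\\right) = \\log\\frac{D_i / (D_i + R_i)}{R_i / (D_i + R_i)} = \\log\\frac{D_i}{R_i}.\n",
    "$$\n",
    "\n",
    "As an illustration using the 1976 Alabama row printed earlier (the prior mean itself uses the 2012 columns), $D = 659170$ and $R = 504070$ give a Democrat share of about $0.567$ and a logit of $\\log(659170 / 504070) \\approx 0.268$."
   ]
  },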
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our prior distribution for $\\alpha$ will be a multivariate Normal with mean `prior_mean`. The only thing left to decide upon is the covariance matrix. Since `alpha` values are logit-transformed, the covariance will be defined in logit space as well.\n",
    "\n",
    "*Aside*: The prior covariance is important in a number of ways. If we allow too much variance, the prior will be uncertain about the outcome in every state, and require polling everywhere. If we allow too little variance, we may be caught off-guard by an unexpected electoral outcome. If we assume states are independent, then we will not be able to pool information across states; but if we assume too much correlation, we could rely too heavily on poll results in one state when making predictions about another."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We select the prior covariance by taking the empirical covariance from the elections 1976 - 2012 and adding a small value `0.01` to the diagonal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = 2 * torch.arange(10)\n",
    "as_tensor = torch.tensor(frame.values, dtype=torch.float)\n",
    "logits = torch.log(as_tensor[..., idx] / as_tensor[..., idx + 1]).transpose(0, 1)\n",
    "mean = logits.mean(0)\n",
    "sample_covariance = (1/(logits.shape[0] - 1)) * (\n",
    "    (logits.unsqueeze(-1) - mean) * (logits.unsqueeze(-2) - mean)\n",
    ").sum(0)\n",
    "prior_covariance = sample_covariance + 0.01 * torch.eye(sample_covariance.shape[0])"
   ]
  },
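  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In symbols, and as a sketch in our own notation, let $\\alpha^{(t)}$ be the 51-vector of state logits for election $t$, let $T = 10$ be the number of elections, and let $\\bar{\\alpha}$ be the average of the $\\alpha^{(t)}$. The cell above then computes\n",
    "\n",
    "$$\n",
    "\\hat{\\Sigma} = \\frac{1}{T-1}\\sum_{t=1}^{T} \\left(\\alpha^{(t)} - \\bar{\\alpha}\\right)\\left(\\alpha^{(t)} - \\bar{\\alpha}\\right)^\\top, \\qquad \\Sigma_{\\text{prior}} = \\hat{\\Sigma} + 0.01\\, I_{51}.\n",
    "$$\n",
    "\n",
    "The diagonal term is doing more than smoothing. Because $\\hat{\\Sigma}$ is built from only 10 centred vectors, its rank is at most 9, so it is singular, whereas `dist.MultivariateNormal` needs a positive definite covariance. Adding $0.01\\, I_{51}$ raises every eigenvalue to at least $0.01$ (up to floating point error), which makes $\\Sigma_{\\text{prior}}$ positive definite."
   ]
  },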
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up the model\n",
    "We are now in a position to define our model. At a high level the model works as follows:\n",
    " - $\\alpha$ is multivariate Normal\n",
    " - $w$ is a deterministic function of $\\alpha$\n",
    " - $y_i$ is Binomial($n_i$, sigmoid($\\alpha_i$)) so we are assuming that people respond to the poll in exactly the same way that they will vote on election day\n",
    " \n",
    "In Pyro, this model looks as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "\n",
    "def model(polling_allocation):\n",
    "    # This allows us to run many copies of the model in parallel\n",
    "    with pyro.plate_stack(\"plate_stack\", polling_allocation.shape[:-1]):\n",
    "        # Begin by sampling alpha\n",
    "        alpha = pyro.sample(\"alpha\", dist.MultivariateNormal(\n",
    "            prior_mean, covariance_matrix=prior_covariance))\n",
    "        \n",
    "        # Sample y conditional on alpha\n",
    "        poll_results = pyro.sample(\"y\", dist.Binomial(\n",
    "            polling_allocation, logits=alpha).to_event(1))\n",
    "        \n",
    "        # Now compute w according to the (approximate) electoral college formula\n",
    "        dem_win = election_winner(alpha)\n",
    "        pyro.sample(\"w\", dist.Delta(dem_win))\n",
    "        \n",
    "        return poll_results, dem_win, alpha"
   ]
  },
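  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for how informative a poll of a given size is under this Binomial likelihood, a back-of-the-envelope calculation may help. Writing $p_i = \\sigma(\\alpha_i)$ (our own shorthand), the observed Democrat fraction in state $i$ satisfies\n",
    "\n",
    "$$\n",
    "\\mathbb{E}\\left[\\frac{y_i}{n_i} \\,\\Big|\\, \\alpha_i\\right] = p_i, \\qquad \\text{sd}\\left(\\frac{y_i}{n_i} \\,\\Big|\\, \\alpha_i\\right) = \\sqrt{\\frac{p_i (1 - p_i)}{n_i}}.\n",
    "$$\n",
    "\n",
    "For a state near $p_i = 0.5$, polling $n_i = 1000$ people gives a standard deviation of about $0.016$, whereas polling only $n_i = 20$ people gives about $0.11$. Under these assumptions, a small poll in a single state says little about whether that state's share is above or below one half unless the state is already far from marginal."
   ]
  },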
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Understanding the prior\n",
    "\n",
    "Before we go any further, we're going to study the model to check it matches with our intuition about US presidential elections.\n",
    "\n",
    "First of all, let's look at the upper and lower limits of a 95% prior interval (the prior mean plus or minus 1.96 standard deviations in logit space, mapped back through the sigmoid) for the proportion of voters who will vote Democrat in each state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "       Lower confidence limit  Upper confidence limit\n",
      "State                                                \n",
      "AL                   0.272258                0.517586\n",
      "AK                   0.330472                0.529117\n",
      "AZ                   0.321011                0.593634\n",
      "AR                   0.214348                0.576079\n",
      "CA                   0.458618                0.756616\n"
     ]
    }
   ],
   "source": [
    "std = prior_covariance.diag().sqrt()\n",
    "ci = pd.DataFrame({\"State\": frame.index,\n",
    "                   \"Lower confidence limit\": torch.sigmoid(prior_mean - 1.96 * std), \n",
    "                   \"Upper confidence limit\": torch.sigmoid(prior_mean + 1.96 * std)}\n",
    "                 ).set_index(\"State\")\n",
    "print(ci.head())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The prior on $\\alpha$ implicitly defines our prior on `w`. We can investigate this prior by simulating many times from the prior."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prior probability of Dem win 0.6799200177192688\n"
     ]
    }
   ],
   "source": [
    "_, dem_wins, alpha_samples = model(torch.ones(100000, 51))\n",
    "prior_w_prob = dem_wins.float().mean()\n",
    "print(\"Prior probability of Dem win\", prior_w_prob.item())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since our prior is based on 2012 and the Democrats won in 2012, it makes sense that we would favour a Democrat win in 2016 (this is before we have seen *any* polling data or incorporated any other information)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also investigate which states, a priori, are most marginal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "       Democrat win probability\n",
      "State                          \n",
      "FL                      0.52501\n",
      "NC                      0.42730\n",
      "NH                      0.61536\n",
      "OH                      0.61997\n",
      "VA                      0.63738\n"
     ]
    }
   ],
   "source": [
    "dem_prob = (alpha_samples > 0.).float().mean(0)\n",
    "marginal = torch.argsort((dem_prob - .5).abs()).numpy()\n",
    "prior_prob_dem = pd.DataFrame({\"State\": frame.index[marginal],\n",
    "                               \"Democrat win probability\": dem_prob.numpy()[marginal]}\n",
    "                             ).set_index('State')\n",
    "print(prior_prob_dem.head())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a sanity check, and seems to accord with our intuitions. Florida is frequently an important swing state and is top of our list of marginal states under the prior. We can also see states such as Pennsylvania and Wisconsin near the top of the list -- we know that these were instrumental in the 2016 election."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we take a closer look at our prior covariance. Specifically, we examine states that we expect to be more or less correlated. Let's begin by looking at states in New England."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "State        ME        VT        NH        MA        RI        CT\n",
      "State                                                            \n",
      "ME     1.000000  0.817323  0.857351  0.800276  0.822024  0.825383\n",
      "VT     0.817323  1.000000  0.834723  0.716342  0.754026  0.844140\n",
      "NH     0.857351  0.834723  1.000000  0.871370  0.803803  0.873496\n",
      "MA     0.800276  0.716342  0.871370  1.000000  0.813665  0.835148\n",
      "RI     0.822024  0.754026  0.803803  0.813665  1.000000  0.849644\n",
      "CT     0.825383  0.844140  0.873496  0.835148  0.849644  1.000000\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def correlation(cov):\n",
    "    return cov / np.sqrt(np.expand_dims(np.diag(cov.values), 0) * np.expand_dims(np.diag(cov.values), 1))\n",
    "                \n",
    "\n",
    "new_england_states = ['ME', 'VT', 'NH', 'MA', 'RI', 'CT']\n",
    "cov_as_frame = pd.DataFrame(prior_covariance.numpy(), columns=frame.index).set_index(frame.index)\n",
    "ne_cov = cov_as_frame.loc[new_england_states, new_england_states]\n",
    "ne_corr = correlation(ne_cov)\n",
    "print(ne_corr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Clearly, these states tend to vote similarly. We can also examine some states of the South which we also expect to be similar."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "State        LA        MS        AL        GA        SC\n",
      "State                                                  \n",
      "LA     1.000000  0.554020  0.651511  0.523784  0.517672\n",
      "MS     0.554020  1.000000  0.699459  0.784371  0.769198\n",
      "AL     0.651511  0.699459  1.000000  0.829908  0.723015\n",
      "GA     0.523784  0.784371  0.829908  1.000000  0.852818\n",
      "SC     0.517672  0.769199  0.723015  0.852818  1.000000\n"
     ]
    }
   ],
   "source": [
    "southern_states = ['LA', 'MS', 'AL', 'GA', 'SC']\n",
    "southern_cov = cov_as_frame.loc[southern_states, southern_states]\n",
    "southern_corr = correlation(southern_cov)\n",
    "print(southern_corr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These correlation matrices show that, as expected, logical groupings of states tend to have similar voting trends. We now look at the correlations *between* the groups (e.g. between Maine and Louisiana)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "State        LA        MS        AL        GA        SC\n",
      "State                                                  \n",
      "ME     0.329438  0.309352 -0.000534  0.122375  0.333679\n",
      "VT    -0.036079  0.009653 -0.366604 -0.202065  0.034438\n",
      "NH     0.234105  0.146826 -0.105781  0.008411  0.233084\n",
      "MA     0.338411  0.122257 -0.059107 -0.025730  0.182290\n",
      "RI     0.314088  0.188819 -0.066307 -0.022142  0.186955\n",
      "CT     0.139021  0.074646 -0.205797 -0.107684  0.125023\n"
     ]
    }
   ],
   "source": [
    "cross_cov = cov_as_frame.loc[new_england_states + southern_states, new_england_states + southern_states]\n",
    "cross_corr = correlation(cross_cov)\n",
    "print(cross_corr.loc[new_england_states, southern_states])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we see weaker correlation between New England states and Southern states than within those groupings. Again, this is as expected."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Measuring the expected information gain of a polling strategy\n",
    "The prior we have set up appears to accord, at least approximately, with intuition. However, we now want to add a second source of information from polling. We aim to use our prior to select a polling strategy that will be most informative about our target $w$. A polling strategy, in this simplified set-up, is the number of people to poll in each state. (We ignore any other covariates such as regional variation inside states, demographics, etc.) We might imagine that polling 1000 people in Florida (the most marginal state) will be much more effective than polling 1000 people in DC (the least marginal state). That's because the outcome in DC is already quite predictable, just based on our prior, whereas the outcome in Florida is really up for grabs.\n",
    "\n",
    "In fact, the information that our model will gain about $w$ based on conducting a poll with design $d$ and getting outcome $y$ can be described mathematically as follows:\n",
    "\n",
    "$$\\text{IG}(d, y) = KL(p(w|y,d)||p(w)).$$\n",
    "\n",
    "Since the outcome of the poll is at present unknown, we consider the expected information gain [1]\n",
    "\n",
    "$$\\text{EIG}(d) = \\mathbb{E}_{p(y|d)}[KL(p(w|y,d)||p(w))].$$\n",
    "\n",
    "### Variational estimators of EIG\n",
    "\n",
    "In the [working memory tutorial](http://pyro.ai/examples/working_memory.html), we used the 'marginal' estimator to find the EIG. This involved estimating the marginal density $p(y|d)$. In this experiment, that would be relatively difficult: $y$ is 51-dimensional with some rather tricky constraints that make modelling its density difficult. Furthermore, the marginal estimator requires us to know $p(y|w)$ analytically, which we do not.\n",
    "\n",
    "Fortunately, other variational estimators of EIG exist: see [2] for more details. One such variational estimator is the 'posterior' estimator, based on the following representation\n",
    "\n",
    "$$\\text{EIG}(d) = \\max_q \\mathbb{E}_{p(w, y|d)}\\left[\\log q(w|y) \\right] + H(p(w)).$$\n",
    "\n",
    "Here, $H(p(w))$ is the prior entropy on $w$ (we can compute this quite easily). The important term involves the variational approximation $q(w|y)$. This $q$ can be used to perform amortized variational inference. Specifically, it takes as input $y$ and outputs a distribution over $w$. The bound is maximised when $q(w|y) = p(w|y)$ [2]. Since $w$ is a binary random variable, we can think of $q$ as a classifier that tries to decide, based on the poll outcome, who the eventual winner of the election will be. In this notebook, $q$ will be a neural classifier. Training a neural classifier is a fair bit easier than learning the marginal density of $y$, so we adopt this method to estimate the EIG in this tutorial."
   ]
  },
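  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For readers who want to see where the 'posterior' representation comes from, here is a short sketch of the argument (see [2] for the full treatment). Expanding the KL divergence in the definition of the EIG, and using the fact that the prior $p(w)$ does not depend on the design, gives\n",
    "\n",
    "$$\n",
    "\\text{EIG}(d) = \\mathbb{E}_{p(w, y|d)}\\left[\\log p(w|y,d) - \\log p(w)\\right] = \\mathbb{E}_{p(w, y|d)}\\left[\\log p(w|y,d)\\right] + H(p(w)).\n",
    "$$\n",
    "\n",
    "Replacing the true posterior by any approximation $q(w|y)$ can only lower the first term, because the gap is an expected KL divergence:\n",
    "\n",
    "$$\n",
    "\\mathbb{E}_{p(w, y|d)}\\left[\\log p(w|y,d)\\right] - \\mathbb{E}_{p(w, y|d)}\\left[\\log q(w|y)\\right] = \\mathbb{E}_{p(y|d)}\\left[KL(p(w|y,d)||q(w|y))\\right] \\geq 0.\n",
    "$$\n",
    "\n",
    "So every $q$ gives a lower bound on the EIG, and the bound is tight when $q(w|y) = p(w|y,d)$. In practice the network will not reach the maximum exactly, so the estimates should be read as approximate lower bounds."
   ]
  },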
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "class OutcomePredictor(nn.Module):\n",
    "    \n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.h1 = nn.Linear(51, 64)\n",
    "        self.h2 = nn.Linear(64, 64)\n",
    "        self.h3 = nn.Linear(64, 1)\n",
    "        \n",
    "    def compute_dem_probability(self, y):\n",
    "        z = nn.functional.relu(self.h1(y))\n",
    "        z = nn.functional.relu(self.h2(z))\n",
    "        return self.h3(z)\n",
    "    \n",
    "    def forward(self, y_dict, design, observation_labels, target_labels):\n",
    "        \n",
    "        pyro.module(\"posterior_guide\", self)\n",
    "        \n",
    "        y = y_dict[\"y\"]\n",
    "        dem_prob = self.compute_dem_probability(y).squeeze()\n",
    "        pyro.sample(\"w\", dist.Bernoulli(logits=dem_prob))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll now use this to compute the EIG for several possible polling strategies. First, we need to compute the $H(p(w))$ term in the above formula."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "prior_entropy = dist.Bernoulli(prior_w_prob).entropy()"
   ]
  },
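  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough check on this value, the entropy of a Bernoulli variable with success probability $p$ is, in nats,\n",
    "\n",
    "$$\n",
    "H = -p \\log p - (1 - p) \\log (1 - p).\n",
    "$$\n",
    "\n",
    "With the prior probability of a Democrat win estimated above, $p \\approx 0.68$, this gives $H \\approx 0.68 \\times 0.386 + 0.32 \\times 1.139 \\approx 0.627$ nats. Since $w$ is binary, the true EIG of any design can be at most this prior entropy, which gives a useful scale for the EIG estimates computed for each strategy."
   ]
  },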
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's consider four simple polling strategies.\n",
    " 1. Poll 1000 people in Florida only\n",
    " 2. Poll 1000 people in DC only\n",
    " 3. Poll 1000 people spread evenly over the US\n",
    " 4. Poll 1000 people using an allocation that focuses on swing states"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "\n",
    "poll_in_florida = torch.zeros(51)\n",
    "poll_in_florida[9] = 1000\n",
    "\n",
    "poll_in_dc = torch.zeros(51)\n",
    "poll_in_dc[8] = 1000\n",
    "\n",
    "uniform_poll = (1000 // 51) * torch.ones(51)\n",
    "\n",
    "# The swing score measures how close the state is to 50/50.\n",
    "# `dem_prob` is already in the row order of `frame`, like the other allocations.\n",
    "# (Sorting `prior_prob_dem` by state abbreviation would not reproduce that order.)\n",
    "swing_score = 1. / (.5 - dem_prob).abs()\n",
    "swing_poll = 1000 * swing_score / swing_score.sum()\n",
    "swing_poll = swing_poll.round()\n",
    "\n",
    "poll_strategies = OrderedDict([(\"Florida\", poll_in_florida),\n",
    "                               (\"DC\", poll_in_dc),\n",
    "                               (\"Uniform\", uniform_poll),\n",
    "                               (\"Swing\", swing_poll)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll now compute the EIG for each option. Since this requires training the network (four times) it may take several minutes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Florida 0.3088182508945465\n",
      "DC 0.000973820686340332\n",
      "Uniform 0.30316317081451416\n",
      "Swing 0.3254041075706482\n"
     ]
    }
   ],
   "source": [
    "from pyro.contrib.oed.eig import posterior_eig\n",
    "from pyro.optim import Adam\n",
    "\n",
    "eigs = {}\n",
    "best_strategy, best_eig = None, 0\n",
    "\n",
    "for strategy, allocation in poll_strategies.items():\n",
    "    print(strategy, end=\" \")\n",
    "    guide = OutcomePredictor()\n",
    "    pyro.clear_param_store()\n",
    "    # To reduce noise when comparing designs, we will use the precomputed value of H(p(w))\n",
    "    # By passing eig=False, we tell Pyro not to estimate the prior entropy on each run\n",
    "    # The return value of `posterior_eig` is then -E_p(w,y)[log q(w|y)]\n",
    "    ape = posterior_eig(model, allocation, \"y\", \"w\", 10, 12500, guide, \n",
    "                        Adam({\"lr\": 0.001}), eig=False, final_num_samples=10000)\n",
    "    eigs[strategy] = prior_entropy - ape\n",
    "    print(eigs[strategy].item())\n",
    "    if eigs[strategy] > best_eig:\n",
    "        best_strategy, best_eig = strategy, eigs[strategy]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running the experiment\n",
    "\n",
    "We have now scored our four candidate designs and can choose the best one to use to actually gather new data. In this notebook, we will simulate the new data using the results from the 2016 election. Specifically, we will assume that the outcome of the poll comes from our model, where we condition the value of `alpha` to correspond to the actual results in 2016."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we retrain $q$ with the chosen polling strategy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.3653, grad_fn=<DivBackward0>)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "best_allocation = poll_strategies[best_strategy]\n",
    "pyro.clear_param_store()\n",
    "guide = OutcomePredictor()\n",
    "posterior_eig(model, best_allocation, \"y\", \"w\", 10, 12500, guide, \n",
    "              Adam({\"lr\": 0.001}), eig=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The value of $\\alpha$ implied by the 2016 results is computed in the same way we computed the prior mean."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data = pd.read_pickle(urlopen(BASE_URL + \"us_presidential_election_data_test.pickle\"))\n",
    "results_2016 = torch.tensor(test_data.values, dtype=torch.float)\n",
    "true_alpha = torch.log(results_2016[..., 0] / results_2016[..., 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "conditioned_model = pyro.condition(model, data={\"alpha\": true_alpha})\n",
    "y, _, _ = conditioned_model(best_allocation)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's view the outcome of our poll."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "       Number of people polled  Number who said they would vote Democrat\n",
      "State                                                                   \n",
      "FL                       191.0                                      89.0\n",
      "NE                        66.0                                      30.0\n",
      "NJ                        41.0                                      22.0\n",
      "OH                        40.0                                      19.0\n",
      "VT                        35.0                                      22.0\n"
     ]
    }
   ],
   "source": [
    "outcome = pd.DataFrame({\"State\": frame.index,\n",
    "                        \"Number of people polled\": best_allocation, \n",
    "                        \"Number who said they would vote Democrat\": y}\n",
    "                      ).set_index(\"State\")\n",
    "print(outcome.sort_values([\"Number of people polled\", \"State\"], ascending=[False, True]).head())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Analysing the data\n",
    "Having collected our data, we can now perform inference under our model to obtain the posterior probability of a Democrat win. There are many ways to perform the inference: for instance we could use variational inference with Pyro's `SVI` or MCMC such as Pyro's `NUTS`. Using these methods, we would compute the posterior over `alpha`, and use this to obtain the posterior over `w`.\n",
    "\n",
    "However, a quick way to analyze the data is to use the neural network we have already trained. At convergence, we expect the network to give a good approximation to the true posterior, i.e. $q(w|y) \\approx p(w|y)$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "q_w = torch.sigmoid(guide.compute_dem_probability(y).squeeze().detach())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prior probability of Democrat win 0.6799200177192688\n",
      "Posterior probability of Democrat win 0.40527692437171936\n"
     ]
    }
   ],
   "source": [
    "print(\"Prior probability of Democrat win\", prior_w_prob.item())\n",
    "print(\"Posterior probability of Democrat win\", q_w.item())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This completes the whole process of experimental design, obtaining data, and data analysis. By using our analysis model and prior to design the polling strategy, we were able to choose the design that led to the greatest expected gain in information."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusions\n",
    "\n",
    "In this tutorial, we showed how optimal experimental design principles can be applied to electoral polling. If we have the option of first designing the polling strategy using our prior, this can help lead to better predictions once we have collected our data.\n",
    "\n",
    "Our model could be enhanced in a number of ways: a useful guide to the kinds of models used in political science can be found in [3]. Our optimal design strategy will only ever be as good as our prior, so there is certainly a lot of use in investing time in choosing a good prior.\n",
    "\n",
    "It might also be possible to search over the possible polling strategies in a more sophisticated way. For instance, simulated annealing [4] is one technique that can be used for optimization in high-dimensional discrete spaces.\n",
    "\n",
    "Our polling strategies were chosen with one objective in mind: predict the final outcome $w$. If, on the other hand, we wanted to make more fine-grained predictions, we could use the same procedure but treat $\\alpha$ as our target instead of $w$.\n",
    "\n",
    "Finally, it's worth noting that there is a connection to Bayesian model selection: specifically, in a different model we might have $w$ a binary random variable to choose between different sub-models. Although conceptually different, our election set-up and the model selection set-up share many mathematical features which means that a similar experimental design approach could be used in both cases."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "[1] Chaloner, K. and Verdinelli, I., 1995. **Bayesian experimental design: A review.** Statistical Science, pp.273-304.\n",
    "\n",
    "[2] Foster, A., Jankowiak, M., Bingham, E., Horsfall, P., Teh, Y.W., Rainforth, T. and Goodman, N., 2019. **Variational Bayesian Optimal Experimental Design.** Advances in Neural Information Processing Systems 2019 (to appear).\n",
    "\n",
    "[3] Gelman, A., Carlin, J.B., Stern, H.S., Dunson, D.B., Vehtari, A. and Rubin, D.B., 2013. **Bayesian data analysis.** Chapman and Hall/CRC.\n",
    "\n",
    "[4] Kirkpatrick, S., Gelatt, C.D. and Vecchi, M.P., 1983. **Optimization by simulated annealing.** Science, 220(4598), pp.671-680."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
